{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import time\n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F\n",
    "import d2lzh as d2l \n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def nin_block(in_channels,out_channels,kernel_size,stride,padding):\n",
    "    blk = nn.Sequential(\n",
    "        nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(out_channels,out_channels,kernel_size=1),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(out_channels,out_channels,kernel_size=1),\n",
    "        nn.ReLU()\n",
    "        )\n",
    "    return blk \n",
    "\n",
    "class GlobalAvgPool2d(nn.Module):\n",
    "    # 全局平均池化层可以将池化窗口设置为输入的宽和高实现\n",
    "    def __init__(self):\n",
    "        super(GlobalAvgPool2d,self).__init__()\n",
    "    def forward(self,x):\n",
    "        return F.avg_pool2d(x,kernel_size=x.size()[2:])\n",
    "\n",
    "# NiN network for single-channel images. The shapes in the comments are the\n",
    "# ones recorded for a (1, 1, 224, 224) input.\n",
    "net = nn.Sequential(\n",
    "    # Stage 1: large-kernel block, then 3x3/2 max pooling -> (96, 26, 26)\n",
    "    nin_block(1, 96, kernel_size=11, stride=4, padding=0),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    # Stage 2 -> (256, 12, 12)\n",
    "    nin_block(96, 256, kernel_size=5, stride=1, padding=2),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    # Stage 3 -> (384, 5, 5)\n",
    "    nin_block(256, 384, kernel_size=3, stride=1, padding=1),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nn.Dropout(0.5),\n",
    "    # Head: the last block emits 10 channels, which are averaged over the\n",
    "    # spatial dimensions and flattened to a (batch, 10) output.\n",
    "    nin_block(384, 10, kernel_size=3, stride=1, padding=1),\n",
    "    GlobalAvgPool2d(),\n",
    "    d2l.FlattenLayer(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "0 output shape: torch.Size([1, 96, 54, 54])\n1 output shape: torch.Size([1, 96, 26, 26])\n2 output shape: torch.Size([1, 256, 26, 26])\n3 output shape: torch.Size([1, 256, 12, 12])\n4 output shape: torch.Size([1, 384, 12, 12])\n5 output shape: torch.Size([1, 384, 5, 5])\n6 output shape: torch.Size([1, 384, 5, 5])\n7 output shape: torch.Size([1, 10, 5, 5])\n8 output shape: torch.Size([1, 10, 1, 1])\n9 output shape: torch.Size([1, 10])\n"
    }
   ],
   "source": [
    "# Shape check: feed one random single-channel 224x224 image through the\n",
    "# network and print the output shape after every top-level child of `net`.\n",
    "#\n",
    "# The input is allocated on whatever device the parameters currently live on,\n",
    "# so the cell also works when `net` has already been moved off the CPU\n",
    "# (e.g. when it is re-run after training on a GPU).\n",
    "x = torch.rand(1,1,224,224,device=next(net.parameters()).device)\n",
    "# Only shapes are inspected, so skip building an autograd graph.\n",
    "with torch.no_grad():\n",
    "    for name,blk in net.named_children():\n",
    "        x = blk(x)\n",
    "        print(name,'output shape:',x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "training on  cuda\nepoch 1, loss 1.0409, train acc 0.619,test acc 0.781,time 108.8\nepoch 2, loss 0.2574, train acc 0.813,test acc 0.822,time 105.2\nepoch 3, loss 0.1459, train acc 0.840,test acc 0.844,time 105.6\nepoch 4, loss 0.1010, train acc 0.850,test acc 0.851,time 105.8\nepoch 5, loss 0.0760, train acc 0.858,test acc 0.853,time 105.8\n"
    }
   ],
   "source": [
    "# Hyper-parameters for the training run.\n",
    "batch_size = 64\n",
    "lr = 0.002\n",
    "num_epochs = 5\n",
    "\n",
    "# Fashion-MNIST iterators; resize=224 is forwarded to the loader\n",
    "# (presumably rescaling images to 224x224 - confirm in d2l).\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
    "\n",
    "# Adam over all network parameters; the d2l helper runs the training loop\n",
    "# on `device` and prints per-epoch loss / accuracy / time.\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitpytorchnotebookconda6e7a8693df0d4d92aca09d521275d23a",
   "display_name": "Python 3.7.7 64-bit ('pytorch_notebook': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}